{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "946d5a3d",
   "metadata": {},
   "source": [
    "# Serve Multiple Fine-Tuned LoRA Adapters with DJL Serving"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "35756ca1",
   "metadata": {},
   "source": [
    "This notebook demonstrates how to deploy multiple fine-tuned LoRA adapters with a single copy of the base model on SageMaker using the DJL Serving Large Model Inference (LMI) DLC. LoRA (Low-Rank Adaptation) is a technique for fine-tuning large language models that significantly reduces the number of trainable parameters compared to full fine-tuning while achieving comparable or better performance. You can learn more about LoRA in this [paper](https://arxiv.org/abs/2106.09685).\n",
    "\n",
    "A major benefit of LoRA is that fine-tuned adapters can easily be added to and removed from the base model, which makes switching adapters at runtime cheap. In this notebook we deploy a SageMaker endpoint with a single base model and multiple LoRA adapters, and select the adapter on a per-request basis.\n",
    "\n",
    "Since LoRA adapters are much smaller than the base model (realistically 100x-1000x smaller), an endpoint with a single base model and multiple LoRA adapters needs far less hardware than deploying an equivalent number of fully fine-tuned models.\n",
    "\n",
    "The example in this notebook follows the multi-adapter example in Hugging Face's PEFT library: https://github.com/huggingface/peft/blob/main/examples/multi_adapter_examples/PEFT_Multi_LoRA_Inference.ipynb.\n",
    "\n",
    "This notebook uses the built-in handler. For custom handlers, see the [custom notebook](multi_lora_adapter_inference_custom.ipynb)."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c39e3b6d",
   "metadata": {},
   "source": [
    "## Install Packages and Import Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "41a4ff2e",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install huggingface_hub sagemaker boto3 awscli --upgrade --quiet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "12eb06ae",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "from sagemaker import image_uris\n",
    "import boto3\n",
    "import os\n",
    "import time\n",
    "import json\n",
    "from pathlib import Path\n",
    "from sagemaker.utils import name_from_base\n",
    "from huggingface_hub import snapshot_download"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3d37b203",
   "metadata": {},
   "source": [
    "## Download Model Artifacts and Upload to S3"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8d1515c5",
   "metadata": {},
   "source": [
    "We will be deploying an endpoint with 2 LoRA adapters. These are the models we will be using:\n",
    "- Base Model: https://huggingface.co/huggyllama/llama-7b\n",
    "- LoRA Fine Tuned Adapter 1: https://huggingface.co/tloen/alpaca-lora-7b\n",
    "- LoRA Fine Tuned Adapter 2: https://huggingface.co/22h/cabrita-lora-v0-1"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8d57e9d5",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf lora-multi-adapter\n",
    "!mkdir -p lora-multi-adapter/adapters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "978224f5",
   "metadata": {},
   "outputs": [],
   "source": [
    "snapshot_download(\"tloen/alpaca-lora-7b\", local_dir=\"lora-multi-adapter/adapters/eng_alpaca\", local_dir_use_symlinks=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ba747f00",
   "metadata": {},
   "outputs": [],
   "source": [
    "snapshot_download(\"22h/cabrita-lora-v0-1\", local_dir=\"lora-multi-adapter/adapters/portuguese_alpaca\", local_dir_use_symlinks=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d89777de",
   "metadata": {},
   "source": [
    "## Creating Inference Handler and DJL Serving Configuration"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "89bd1577",
   "metadata": {},
   "source": [
    "The model server is configured through `serving.properties`. Many of the built-in handlers, such as vLLM and LMI-Dist, support adapters out of the box; check the [backend's user guide](http://docs.djl.ai/docs/serving/serving/docs/lmi/index.html#supported-lmi-inference-libraries) to confirm. You can also write a custom `model.py` by following the instructions in the [custom adapter notebook](http://docs.djl.ai/docs/demos/aws/sagemaker/large-model-inference/sample-llm/multi_lora_adapter_inference_custom.html).\n",
    "\n",
    "The core structure to cover here is the model directory. We include both the base model and LoRA adapters in the model directory like this:\n",
    "\n",
    "```\n",
    "|- model_dir\n",
    "    |- adapters/\n",
    "        |--- <adapter_1>/\n",
    "        |--- <adapter_2>/\n",
    "        |--- ...\n",
    "        |--- <adapter_n>/\n",
    "    |- serving.properties\n",
    "    |- model.py (optional)\n",
    "\n",
    "```\n",
    "\n",
    "The model files can also live in a separate S3 location, specified with an S3 URI as `option.model_id` in `serving.properties`. In this case, the `adapters` directory can be located either alongside `serving.properties` or alongside the model files in S3.\n",
    "\n",
    "Each subdirectory of `adapters` contains the artifacts for one LoRA adapter. Typically there are two files: the adapter weights (`adapter_model.bin`, or `adapter_model.safetensors` in newer PEFT versions) and the adapter configuration (`adapter_config.json`). These are typically produced by the PEFT library's `PeftModel.save_pretrained()` method."
   ]
  },
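  {
   "cell_type": "markdown",
   "id": "c4e1a7b2",
   "metadata": {},
   "source": [
    "As a quick sanity check before packaging, the sketch below lists the adapter directories downloaded earlier and reports whether each one contains an `adapter_config.json`. In this notebook the directory names (`eng_alpaca`, `portuguese_alpaca`) are also the adapter names used in requests. This check is illustrative only and is not required by DJL Serving."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f09d3c6e",
   "metadata": {},
   "outputs": [],
   "source": [
    "from pathlib import Path\n",
    "\n",
    "# Illustrative check: each adapter directory should contain an adapter_config.json\n",
    "adapters_dir = Path(\"lora-multi-adapter/adapters\")\n",
    "for adapter_dir in sorted(p for p in adapters_dir.iterdir() if p.is_dir()):\n",
    "    has_config = (adapter_dir / \"adapter_config.json\").is_file()\n",
    "    print(f\"{adapter_dir.name}: adapter_config.json {'found' if has_config else 'MISSING'}\")"
   ]
  },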
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5e988852",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%writefile lora-multi-adapter/serving.properties\n",
    "option.model_id=huggyllama/llama-7b\n",
    "option.rolling_batch=lmi-dist\n",
    "option.tensor_parallel_degree=max\n",
    "option.enable_lora=true\n",
    "option.max_loras=10"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "88c45680",
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -f model.tar.gz\n",
    "!rm -rf lora-multi-adapter/.ipynb_checkpoints\n",
    "!tar czvf model.tar.gz -C lora-multi-adapter ."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3a6dca13",
   "metadata": {},
   "source": [
    "## Create SageMaker Model and Endpoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3cc6e59a",
   "metadata": {},
   "outputs": [],
   "source": [
    "role = sagemaker.get_execution_role()  # execution role for the endpoint\n",
    "sess = sagemaker.session.Session()  # sagemaker session for interacting with different AWS APIs\n",
    "bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "model_bucket = sess.default_bucket()  # bucket to house artifacts\n",
    "s3_code_prefix = \"hf-large-model-djl/lora-multi-adapter\"  # folder within bucket where code artifact will go\n",
    "\n",
    "region = sess.boto_region_name  # public attribute; avoids relying on the private _region_name\n",
    "account_id = sess.account_id()\n",
    "\n",
    "s3_client = boto3.client(\"s3\")\n",
    "sm_client = boto3.client(\"sagemaker\")\n",
    "smr_client = boto3.client(\"sagemaker-runtime\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5ab57364",
   "metadata": {},
   "outputs": [],
   "source": [
    "s3_code_artifact = sess.upload_data(\"model.tar.gz\", bucket, s3_code_prefix)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61e67228",
   "metadata": {},
   "outputs": [],
   "source": [
    "inference_image_uri = image_uris.retrieve(\n",
    "    framework=\"djl-lmi\",\n",
    "    region=region,\n",
    "    version=\"0.31.0\",\n",
    ")\n",
    "model_name_acc = name_from_base(\"lora-multi-adapter\")\n",
    "\n",
    "create_model_response = sm_client.create_model(\n",
    "    ModelName=model_name_acc,\n",
    "    ExecutionRoleArn=role,\n",
    "    PrimaryContainer={\n",
    "        \"Image\": inference_image_uri,\n",
    "        \"ModelDataUrl\": s3_code_artifact,\n",
    "    },\n",
    ")\n",
    "model_arn = create_model_response[\"ModelArn\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b8699f5d",
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoint_config_name = f\"{model_name_acc}-config\"\n",
    "endpoint_name = f\"{model_name_acc}-endpoint\"\n",
    "\n",
    "endpoint_config_response = sm_client.create_endpoint_config(\n",
    "    EndpointConfigName=endpoint_config_name,\n",
    "    ProductionVariants=[\n",
    "        {\n",
    "            \"VariantName\": \"variant1\",\n",
    "            \"ModelName\": model_name_acc,\n",
    "            \"InstanceType\": \"ml.g5.12xlarge\",\n",
    "            \"InitialInstanceCount\": 1,\n",
    "            \"ModelDataDownloadTimeoutInSeconds\": 1800,\n",
    "            \"ContainerStartupHealthCheckTimeoutInSeconds\": 1800,\n",
    "        },\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1a1585b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(f\"endpoint_name: {endpoint_name}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a346666f",
   "metadata": {},
   "outputs": [],
   "source": [
    "create_endpoint_response = sm_client.create_endpoint(\n",
    "    EndpointName=endpoint_name, EndpointConfigName=endpoint_config_name\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4e6ad681",
   "metadata": {},
   "outputs": [],
   "source": [
    "import time\n",
    "\n",
    "resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "status = resp[\"EndpointStatus\"]\n",
    "print(\"Status: \" + status)\n",
    "\n",
    "while status == \"Creating\":\n",
    "    time.sleep(60)\n",
    "    resp = sm_client.describe_endpoint(EndpointName=endpoint_name)\n",
    "    status = resp[\"EndpointStatus\"]\n",
    "    print(\"Status: \" + status)\n",
    "\n",
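    "print(\"Arn: \" + resp[\"EndpointArn\"])\n",
    "print(\"Status: \" + status)\n",
    "\n",
    "# Surface deployment failures here rather than at the first invoke_endpoint call\n",
    "if status != \"InService\":\n",
    "    raise RuntimeError(f\"Endpoint creation failed: {resp.get('FailureReason', 'unknown reason')}\")"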
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7680b38a",
   "metadata": {},
   "source": [
    "## Make Inference Requests"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "84948200",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "response_model = smr_client.invoke_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    Body=json.dumps({\"inputs\": \"Tell me about Alpacas\", \"adapters\": \"eng_alpaca\"}),\n",
    "    ContentType=\"application/json\",\n",
    ")\n",
    "\n",
    "response_model[\"Body\"].read().decode(\"utf8\")"
   ]
  },
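  {
   "cell_type": "markdown",
   "id": "7b2e91aa",
   "metadata": {},
   "source": [
    "To switch adapters, change only the `adapters` field of the request. The sketch below targets the second adapter, `portuguese_alpaca`, which is the directory name used when the adapter was downloaded. The Portuguese prompt is an illustrative assumption, not taken from the adapter's documentation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d3f8a015",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# Same endpoint, different adapter: only the \"adapters\" value changes.\n",
    "# The prompt below is illustrative.\n",
    "response_model = smr_client.invoke_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    Body=json.dumps({\"inputs\": \"Fale-me sobre alpacas\", \"adapters\": \"portuguese_alpaca\"}),\n",
    "    ContentType=\"application/json\",\n",
    ")\n",
    "\n",
    "response_model[\"Body\"].read().decode(\"utf8\")"
   ]
  },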
  {
   "cell_type": "markdown",
   "id": "2d434c23",
   "metadata": {},
   "source": [
    "Inference request targeting the base model without any adapters"
   ]
  },
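  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e5a90b47",
   "metadata": {},
   "outputs": [],
   "source": [
    "%%time\n",
    "\n",
    "# A minimal sketch, assuming that a request with no \"adapters\" field\n",
    "# is served by the base model alone.\n",
    "response_model = smr_client.invoke_endpoint(\n",
    "    EndpointName=endpoint_name,\n",
    "    Body=json.dumps({\"inputs\": \"Tell me about Alpacas\"}),\n",
    "    ContentType=\"application/json\",\n",
    ")\n",
    "\n",
    "response_model[\"Body\"].read().decode(\"utf8\")"
   ]
  },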
  {
   "cell_type": "markdown",
   "id": "cef62f3e",
   "metadata": {},
   "source": [
    "## Clean up Resources"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "baadd61d",
   "metadata": {},
   "outputs": [],
   "source": [
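    "sm_client.delete_endpoint(EndpointName=endpoint_name)\n",
    "sm_client.delete_endpoint_config(EndpointConfigName=endpoint_config_name)\n",
    "sm_client.delete_model(ModelName=model_name_acc)  # also remove the SageMaker model created above"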
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.14"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
